#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
@Software: PyCharm
@Author:  zhaojianghua
@Email: zhaojianghua1990@qq.com
@Homepage: None
"""

import random
import copy
import collections
import bisect

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


def swish(x):
    """Swish activation: x * sigmoid(x).

    Equivalent to the original ``x / (1 + exp(-x))`` form, but written via
    ``torch.sigmoid`` which is numerically stable for large-magnitude inputs
    (avoids computing ``exp(-x)`` explicitly, which overflows to inf in
    float32 for x < ~-88).
    """
    return x * torch.sigmoid(x)


class Model(nn.Module):
    """Q-network for CartPole: 4 observation features -> 2 action values.

    A small MLP (4 -> 64 -> 16 -> 2) with swish activations on the hidden
    layers and a linear output head (raw Q-values, no activation).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 64)
        self.fc2 = nn.Linear(64, 16)
        self.fc3 = nn.Linear(16, 2)

    def forward(self, inputs):
        out = inputs
        # Hidden layers share the same activation; the head stays linear.
        for hidden in (self.fc1, self.fc2):
            out = swish(hidden(out))
        return self.fc3(out)


class ReplayMemory(object):
    """Fixed-capacity experience buffer with uniform random sampling.

    Stores (obs, action, reward, next_obs, done) tuples; once full, the
    oldest transition is evicted automatically by the bounded deque.
    """

    def __init__(self, max_size=1024):
        self.replays = collections.deque(maxlen=max_size)

    def append(self, x):
        """Store one (obs, action, reward, next_obs, done) transition."""
        self.replays.append(x)

    def sample(self, batch_size=64):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns five float32 numpy arrays: observations, actions, rewards,
        next observations, and done flags, in that order.
        """
        picks = random.sample(range(len(self)), batch_size)
        rows = [self.replays[p] for p in picks]
        # Transpose the list of rows into one column per transition field.
        columns = zip(*rows)
        return tuple(np.asarray(col, dtype=np.float32) for col in columns)

    def __len__(self):
        return len(self.replays)


class RewardSlotReplayMemory(object):
    """Replay buffer that groups finished episodes into reward buckets.

    Incoming transitions are staged in a scratch deque (the last one in
    ``self.replays``).  When an episode ends, ``move_to_slot`` drains the
    stage into the bucket selected by the episode's total reward via a
    bisect over ``reward_slots``.  Sampling is uniform over all stored
    transitions, staged or bucketed.
    """

    def __init__(self, reward_slots, max_size=1024):
        self.reward_slots = reward_slots
        n_slots = len(reward_slots)
        # n_slots small buckets, one full-size overflow bucket for rewards
        # beyond the last slot boundary, and one full-size staging deque.
        capacities = [max_size // n_slots] * n_slots + [max_size, max_size]
        self.replays = [collections.deque(maxlen=c) for c in capacities]

    def append(self, x):
        """Stage one transition; it is bucketed later by move_to_slot()."""
        self.replays[-1].append(x)

    def move_to_slot(self, reward):
        """Drain the staging deque into the bucket matching `reward`."""
        bucket = self.replays[bisect.bisect_right(self.reward_slots, reward)]
        stage = self.replays[-1]
        while stage:
            bucket.append(stage.popleft())

    def _row(self, flat):
        """Map a flat index over the concatenated deques to its transition."""
        for dq in self.replays:
            if flat < len(dq):
                return dq[flat]
            flat -= len(dq)
        raise IndexError(flat)

    def sample(self, batch_size=64):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns five float32 numpy arrays: observations, actions, rewards,
        next observations, and done flags, in that order.
        """
        picks = random.sample(range(len(self)), batch_size)
        rows = [self._row(p) for p in picks]
        columns = zip(*rows)
        return tuple(np.asarray(col, dtype=np.float32) for col in columns)

    def __len__(self):
        return sum(map(len, self.replays))


class DQN(object):
    """Deep Q-learning agent with a frozen target network and replay buffer.

    Uses epsilon-greedy exploration, SGD-with-momentum updates, and a
    periodically-synchronized target network for bootstrapped TD targets.
    """

    def __init__(self, env):
        self.env = env
        self.net = Model()
        # Target net starts as an exact copy and is only refreshed by
        # sync_weights(); kept in eval mode since it is never trained.
        self.target_net = copy.deepcopy(self.net)
        self.target_net.eval()

        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=0.001,
                                   momentum=0.9)
        self.gama = 0.95               # discount factor
        self.explore_epsilon = 0.1     # exploration probability
        self.replays = ReplayMemory(max_size=4096)

    def action(self, observation, explore=True, pr=False):
        """Pick an action: random with prob epsilon when exploring, else greedy.

        When `pr` is true, prints the input and Q-values for debugging.
        """
        if explore and random.random() < self.explore_epsilon:
            return self.env.action_space.sample()
        # Temporarily switch to eval mode for a deterministic forward pass.
        self.net.eval()
        with torch.no_grad():
            batch = torch.tensor([observation], dtype=torch.float32)
            scores = self.net(batch)
        self.net.train()
        if pr:
            print(batch.detach().numpy(), scores.detach().numpy())
        return scores.argmax().numpy()

    def update(self, observation, action, reward, observation_next, done):
        """Store one transition, then do one SGD step on a sampled batch."""
        self.replays.append((observation, action, reward, observation_next, done))
        obs, act, rew, nxt, fin = self.replays.sample()

        self.optimizer.zero_grad()
        q_pred = self.net(torch.tensor(obs))
        with torch.no_grad():
            # Bootstrapped target from the frozen net; terminal states
            # (fin == 1) contribute only their immediate reward.
            q_next = self.target_net(torch.tensor(nxt)).max(dim=1)[0]
            target = self.gama * q_next * (1 - torch.tensor(fin)) + rew

        # One-hot over the two discrete actions selects each row's Q-value.
        one_hot = np.stack([1 - act, act], axis=-1)
        chosen = (q_pred * torch.tensor(one_hot)).sum(dim=1)
        loss = 0.5 * torch.mean((target - chosen) ** 2)
        loss.backward()
        self.optimizer.step()

    def sync_weights(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.net.state_dict())


def evaluate(agent, times=10, pr_final=False):
    """Run `times` greedy episodes on a fresh CartPole-v0 env.

    Exploration is disabled; when `pr_final` is true, the last episode prints
    each step's input and Q-values.  Returns the mean total reward.
    """
    env = gym.make("CartPole-v0")
    totals = []
    for episode in range(times):
        obs = env.reset()
        episode_reward = 0
        done = False
        while not done:
            show = pr_final and episode == times - 1
            act = agent.action(obs, explore=False, pr=show)
            obs, reward, done, info = env.step(act)
            episode_reward += reward
        totals.append(episode_reward)
    return sum(totals) / times


def train():
    """Train a DQN agent on CartPole-v0, evaluating every 50 episodes."""
    env = gym.make("CartPole-v0")
    print(env.metadata)

    agent = DQN(env)

    # Pre-fill the replay buffer with random-policy transitions so the first
    # gradient updates sample from a reasonably full buffer.
    for i in range(200):
        obs = env.reset()
        total_reward = 0
        while True:
            act = env.action_space.sample()
            obs_nxt, reward, done, info = env.step(act)
            # Mask the terminal flag when the episode merely hit the time
            # limit, matching the main loop below: a truncated episode is not
            # a true terminal state, so bootstrapping should continue past it.
            done_value = 1 if done and not info.get('TimeLimit.truncated', False) else 0
            agent.replays.append((obs, act, reward, obs_nxt, done_value))
            obs = obs_nxt
            total_reward += reward
            if done:
                break

    episode = 2000
    steps = 0
    for i in range(episode):
        observation = env.reset()
        total_reward = 0
        while True:
            steps += 1
            action = agent.action(observation)
            # Step the environment to get the next state.
            observation_next, reward, done, info = env.step(action)
            # Only real terminations (pole fell / cart out of bounds) count as
            # done for the TD target; time-limit truncation is not terminal.
            done_value = 1 if done and not info.get('TimeLimit.truncated', False) else 0
            if steps % 200 == 0:
                # Periodically freeze the online weights into the target net.
                agent.sync_weights()
            agent.update(observation, action, reward, observation_next, done_value)
            observation = observation_next
            total_reward += reward
            if done:
                break
        print("episode:", i, "steps:", steps, "total_reward:", total_reward)
        if (i + 1) % 50 == 0:
            print("episode:", i, "evaluate:", evaluate(agent, 10))

    print("Final evaluate:", evaluate(agent, 100, pr_final=True))


# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    train()
